#!/bin/bash

# FLUX训练脚本 - 直接传参
# 使用方法: ./train_flux.sh

# ---------------------------------------------------------------------------
# Toolchain: Python interpreter under $AI_DIR/.train and the sd-scripts
# FLUX LoRA entry point bundled with the ComfyUI LoRA-training node.
# ---------------------------------------------------------------------------
WORKSPACE="/workspace"
AI_DIR="${WORKSPACE}/AI"
PYTHON_PATH="${AI_DIR}/.train/bin/python"
SCRIPT_PATH="${AI_DIR}/ComfyUI/custom_nodes/comfyui_lora_train/sd_scripts/flux_train_network.py"

# ---------------------------------------------------------------------------
# Model weights, all located under ComfyUI's models directory.
# ---------------------------------------------------------------------------
MODELS_DIR="${AI_DIR}/ComfyUI/models"
FLUX_MODEL="${MODELS_DIR}/checkpoints/flux1-schnell.safetensors"
CLIP_L_MODEL="${MODELS_DIR}/clip/clip_l.safetensors"
T5XXL_MODEL="${MODELS_DIR}/text_encoders/t5xxl_fp16.safetensors"
AE_MODEL="${MODELS_DIR}/vae/ae.safetensors"

# ---------------------------------------------------------------------------
# Dataset input and training output locations.
# ---------------------------------------------------------------------------
TRAIN_DATA_DIR="${WORKSPACE}/training_data/images"
OUTPUT_DIR="${WORKSPACE}/training_data/output"

# Launch FLUX LoRA training through sd-scripts with parameters tuned for a
# 16 GB VRAM budget.
#
# Any arguments passed to this script are appended after the defaults below
# and forwarded to flux_train_network.py unchanged. sd-scripts parses its
# options with argparse, so for options that take a value the last occurrence
# should win. Example:
#   ./train_flux.sh --max_train_steps 500 --resolution 512
# Boolean switches already set below (e.g. --flip_aug) cannot be turned off
# this way; edit the array instead.
set -euo pipefail

# Print an error message to stderr and abort the script.
die() {
  printf 'error: %s\n' "$*" >&2
  exit 1
}

# Fail fast with a clear message instead of a late crash inside the trainer.
[[ -x "$PYTHON_PATH" ]] || die "python interpreter missing or not executable: $PYTHON_PATH"
for required_file in "$SCRIPT_PATH" "$FLUX_MODEL" "$CLIP_L_MODEL" "$T5XXL_MODEL" "$AE_MODEL"; do
  [[ -f "$required_file" ]] || die "required file not found: $required_file"
done
[[ -d "$TRAIN_DATA_DIR" ]] || die "training data directory not found: $TRAIN_DATA_DIR"
mkdir -p -- "$OUTPUT_DIR" || die "cannot create output directory: $OUTPUT_DIR"

# An array keeps every argument a separate, correctly quoted word. It also
# avoids the fragile backslash continuations, where stray whitespace after a
# '\' silently cuts the command short.
train_args=(
  # Base model components.
  --pretrained_model_name_or_path "$FLUX_MODEL"
  --clip_l "$CLIP_L_MODEL"
  --t5xxl "$T5XXL_MODEL"
  --ae "$AE_MODEL"

  # Dataset, augmentation and caching.
  --train_data_dir "$TRAIN_DATA_DIR"
  --caption_extension .txt
  --resolution 384
  --enable_bucket
  --flip_aug
  --cache_latents_to_disk
  --cache_text_encoder_outputs
  --cache_text_encoder_outputs_to_disk
  --persistent_data_loader_workers
  --max_data_loader_n_workers 1

  # LoRA network. The text encoders are not trained
  # (--network_train_unet_only), which is what allows caching their outputs.
  --network_module networks.lora_flux
  --network_dim 32
  --network_train_unet_only
  # NOTE(review): no --network_alpha is given, so sd-scripts' default applies
  # (believed to be 1, i.e. an effective LoRA scale of 1/32). Confirm this is
  # intended.

  # Optimizer and learning-rate schedule.
  --optimizer_type AdaFactor
  --optimizer_args "relative_step=False" "scale_parameter=False" "warmup_init=False"
  --learning_rate 1e-4
  --lr_scheduler constant_with_warmup
  # NOTE(review): no --lr_warmup_steps is set, so the warmup length is
  # sd-scripts' default (believed to be 0), making this effectively a constant
  # schedule.
  --max_grad_norm 0.0

  # Flow-matching timestep sampling and prediction settings.
  --timestep_sampling shift
  --discrete_flow_shift 3.1582
  --model_prediction_type raw
  --guidance_scale 1.0

  # Precision and memory settings for the 16 GB target.
  --mixed_precision bf16
  --gradient_checkpointing
  --sdpa
  --blocks_to_swap 8
  --disable_mmap_load_safetensors
  # NOTE(review): --highvram disables some of sd-scripts' low-VRAM
  # housekeeping. That looks at odds with a 16 GB target that also swaps
  # blocks; confirm it is intended.
  --highvram

  # Run length, reproducibility and output.
  --max_train_steps 200
  --seed 42
  --save_every_n_epochs 1
  --output_dir "$OUTPUT_DIR"
  --output_name flux_lora
  --save_model_as safetensors
  --save_precision bf16
)

"$PYTHON_PATH" "$SCRIPT_PATH" "${train_args[@]}" "$@"